import copy

import numpy as np
import pandas as pd

import libpyhat.clustering.cluster as cluster
import libpyhat.emi.emi as endmember_identify
import libpyhat.transform.cal_tran as cal_tran
import libpyhat.transform.cal_tran_cv as cal_tran_cv
import libpyhat.transform.deriv as deriv
import libpyhat.transform.dim_red as dim_red
import libpyhat.transform.interp as interp
import libpyhat.transform.mask as mask
import libpyhat.transform.multiply_vector as multiply_vector
import libpyhat.transform.norm as norm
import libpyhat.transform.peak_area as peak_area
import libpyhat.transform.remove_baseline as remove_baseline
import libpyhat.transform.scale as scale
import libpyhat.transform.shift_spect as shift_spect
import libpyhat.utils.folds as folds
import libpyhat.utils.outlier_identify as outlier_identify
from libpyhat.Unmixing import unmix
from libpyhat.utils.lookup import lookup
from libpyhat.utils.utils import enumerate_duplicates
from libpyhat.utils.utils import remove_rows


class SpectralData(object):
    """Native container used to store spectral data in PyHAT.

    Image cubes, point spectra, etc. are translated into this object, and this
    object is passed around to PyHAT functionalities. Where necessary, those
    functionalities translate it into the formats required by scikit-learn and
    other packages or functions, according to their respective
    API/interfacing requirements.

    Parameters:
        df : pandas DataFrame with the multi-index column structure described
            below, or with tuple column labels that can be converted to it.
            NOTE: the frame is stored as ``self.df`` without copying, and it
            may be modified in place during construction (columns dropped,
            relabelled, and spectral values converted to numbers).
        name : optional dataset name (used, e.g., by combine_spectral_data).
        meta_label, spect_label, comp_label : top-level column labels for
            metadata, spectra and compositions (defaults "meta", "wvl" and
            "comp"). A label that is not found in ``df`` is reported and set
            to None on the instance.
        geodata : optional geodata carried along unchanged, e.g. when working
            with an orbital cube.

    Notes: The structure of the pandas DataFrame expected by this class is:

        |meta|meta|...|wvl|wvl|...|comp|comp|...
        |metadata_category|another_metadata_category|...|wavelength_value|
        |wavelength_value|...|composition_category|composition_category|...
      0 |val|val|...|val|val|...|val|val|...
      1 |val|val|...|val|val|...|val|val|...
      ...
      N |val|val|...|val|val|...|val|val|...

      Metadata categories can be strings, floats or ints, and have no
      expected or enforced datatypes. However, common practice is that these
      categories are strings, e.g. "target_name", "latitude [degrees]".

      An attempt is made to convert all second-level header values to floats.
      This is expected to fail for non-numerical strings, like most metadata
      and composition categories. Wavelength values, however, are expected to
      be ints or floats. If this conversion fails for a wavelength column
      (e.g. because it contains special characters, such as '<125' for
      intensities at wavelengths below 125 wavelength units), a warning is
      printed and the label is kept as-is. The rename-column functionality in
      PyHAT can help the user address this after the class is instantiated.

      Composition category names can also be strings, floats or ints, and
      have no expected or enforced datatypes. However, common practice is
      that these categories are strings, such as "MnO [ppm]" and
      "Olivine [wt%]".

      Spectral intensities are expected to be numeric, and an attempt is made
      to convert them to numbers. If this fails, a warning message is
      printed. This can happen when a non-numeric value or string is
      present, e.g. when an intensity reads '<12.5' or '~12.5'. The user can
      use class features to convert these intensities to numerical values.

      To-do:
        Introduce a class feature to convert composition or spectral
        intensity values to a numerical value of the user's choice,
        e.g. ConvertLessThanToValue(data, value='0').

      To-do:
        Introduce this functionality: the indexes in the first column can be
        provided by the user, but if missing, will be assigned. They will be
        enforced to start from 0 and count up to the number of spectra, N.

      To-do:
        Explicit handling of identical combinations of first- and
        second-level columns, whether tuples or otherwise, e.g.
        ('meta', 'target_type') and ('meta', 'target_type') both being in the
        same dataset.

      To-do: We need to handle the case where the columns are neither tuples
      nor in the native format.

      Spectral datasets do not need *all* three expected top-level column
      headers ('wvl', 'meta', 'comp'), but at least one of them must be
      present. If none is present, an exception is raised and class
      instantiation is interrupted.

      If the data is not already a (two-or-more-level) multi-index and the
      user provides tuple columns whose top-level value is not one of the
      recognised labels, those columns are dropped.

      The user can set the top-level column labels (see __init__ args);
      however, certain PyHAT functionalities expect certain columns to be
      present.
    """

    def __init__(
        self,
        df,
        name=None,
        meta_label="meta",
        spect_label="wvl",
        comp_label="comp",
        geodata=None,
    ):
        self.name = name
        # Lets us carry geodata info along, e.g. when working with an orbital
        # cube.
        self.geodata = geodata
        self.meta_label = meta_label
        self.spect_label = spect_label
        self.comp_label = comp_label

        # Check each requested top-level label against the data. Labels that
        # are missing are reported and reset to None on the instance.
        top_level_columns = []
        for attr, label, description in (
            ("meta_label", meta_label, "metadata"),
            ("comp_label", comp_label, "composition"),
            ("spect_label", spect_label, "spectral data"),
        ):
            if label is None:
                continue
            if self._has_top_level_label(df, label):
                top_level_columns.append(label)
            else:
                print(
                    f"The specified {description} label ({label}) was not "
                    "found in the data frame!"
                )
                print(f"Setting {attr} to None")
                setattr(self, attr, None)

        # At least one of the expected top-level columns must be present.
        # The f-string keeps labels passed as None from turning this into a
        # TypeError.
        if not top_level_columns:
            raise Exception(
                "ERROR: The data frame does not contain columns with the "
                f"labels {meta_label}, {comp_label}, or {spect_label}! "
                "Check your inputs and try again."
            )

        # If the columns are not a (2+ level) multi-index, *assume* they are
        # tuples that can be converted, e.g.
        #    |('meta','target_name') | ('wvl',125.5)
        # 0  | 'Made_Up_Name'        | 12345.3
        if not (isinstance(df.columns, pd.MultiIndex) and df.columns.nlevels >= 2):
            self._convert_tuple_columns(df, top_level_columns)

        # Try to turn the second-level headers into floats, which is what
        # makes wavelength labels numeric. This is expected to fail for most
        # metadata/composition names; those labels are kept as-is.
        new_columns = []
        for col in df.columns:
            col = list(col)
            if len(col) > 1:
                try:
                    col[1] = float(col[1])
                except (TypeError, ValueError, OverflowError):
                    if col[0] == self.spect_label:
                        # A non-numeric wavelength can break PyHAT analyses.
                        print(
                            "WARNING: The wavelength value " + str(col[1]) + " failed to "
                            "be converted to a float. This may be caused by duplicate wavelengths. The value will "
                            "be kept as-is at this step but will be removed if remove_duplicates() is run,"
                            " such as during a call to LoadData. If this column should be kept, "
                            " this should be addressed by the user."
                        )
            new_columns.append(tuple(col))

        # Set the columns to their formatted versions
        df.columns = pd.MultiIndex.from_tuples(new_columns)

        # Try to convert spectral intensities to numbers.
        if self.spect_label is not None:
            try:
                df[self.spect_label] = df[self.spect_label].apply(
                    pd.to_numeric, errors="raise"
                )
            except (ValueError, TypeError):
                print(
                    "WARNING: There are spectral intensities that are "
                    "non-numeric. These have failed float conversion and "
                    "could impact analysis."
                )
                # Convert column-by-column, leaving unconvertible columns
                # unchanged (what the deprecated errors="ignore" used to do).
                df[self.spect_label] = df[self.spect_label].apply(
                    self._to_numeric_or_keep
                )

        # store the df in the object
        self.df = df

        if self.spect_label is not None:
            self.get_wvls()
        else:
            self.wvls = None

    @staticmethod
    def _has_top_level_label(df, label):
        """Return True if ``label`` is a top-level column value of ``df``.

        For multi-index columns, the first level is checked. For flat columns,
        tuple column labels whose first element equals ``label`` count (these
        are converted to a multi-index later in __init__).
        """
        columns = df.columns
        if isinstance(columns, pd.MultiIndex):
            return label in columns.get_level_values(0)
        return any(
            isinstance(col, tuple) and len(col) > 0 and col[0] == label
            for col in columns
        )

    @staticmethod
    def _convert_tuple_columns(df, top_level_columns):
        """Convert tuple column labels of ``df`` to a MultiIndex, in place.

        Non-tuple columns, and tuple columns whose first element is not in
        ``top_level_columns``, are dropped (with a warning).
        """
        print(
            "WARNING: The spectra dataset is not in PyHAT's native multi-index "
            "format.\nIt will be converted assuming column labels are tuples "
            "suitable for conversion to multi-index.\nPlease check that this "
            "has been done correctly"
        )

        to_drop = []
        for col in df.columns:
            if not isinstance(col, tuple):
                print(
                    "WARNING: "
                    + str(col)
                    + " is not a tuple (this can be caused by duplicate "
                    "column names). Removing this column."
                )
                to_drop.append(col)
            elif col[0] not in top_level_columns:
                print(
                    "WARNING: You have provided data with a top-level column "
                    "%s that does not match the specified top-level columns: "
                    "%s. *This data will be dropped.* You can either reformat "
                    "your dataset or change the top-level column labels."
                    % (col[0], top_level_columns)
                )
                to_drop.append(col)

        # Mutates the caller's frame, as the original constructor always did.
        df.drop(columns=to_drop, inplace=True)
        df.columns = pd.MultiIndex.from_tuples(list(df.columns))

    @staticmethod
    def _to_numeric_or_keep(values):
        """Return ``pd.to_numeric(values)``, or ``values`` unchanged if that fails."""
        try:
            return pd.to_numeric(values)
        except (ValueError, TypeError):
            return values

    @staticmethod
    def _first_not_none(*labels):
        """Return the first label that is not None, or None if all are None."""
        return next((label for label in labels if label is not None), None)

    def get_wvls(self):
        """Cache the spectral (second-level) column labels in ``self.wvls``."""
        self.wvls = self.df[self.spect_label].columns.values

    def cal_tran(self, A, B, dataAmatchcol, dataBmatchcol, params, Aname, Bname):
        """Apply ``cal_tran.call_cal_tran`` to ``self.df``.

        Stores the returned frame in ``self.df`` and the returned
        calibration-transfer object in ``self.ct_obj``.
        """
        self.df, self.ct_obj = cal_tran.call_cal_tran(
            A,
            B,
            self.df,
            dataAmatchcol,
            dataBmatchcol,
            params,
            spect_label=self.spect_label,
            dataAname=Aname,
            dataBname=Bname,
            dataCname=self.name,
        )

    def cal_tran_cv(self, B, dataAmatchcol, dataBmatchcol, paramgrid, Bname):
        """Run ``cal_tran_cv.call_cal_tran_cv`` with ``self.df`` as dataset A.

        The results are stored in ``self.ct_cv_results``.
        """
        self.ct_cv_results = cal_tran_cv.call_cal_tran_cv(
            self.df,
            B,
            dataAmatchcol,
            dataBmatchcol,
            paramgrid,
            spect_label=self.spect_label,
            dataAname=self.name,
            dataBname=Bname,
        )

    def cluster(self, col, method, params, kws):
        """Replace ``self.df`` with the result of ``cluster.cluster``."""
        self.df = cluster.cluster(self.df, col, method=method, params=params, kws=kws)

    def combine_spectral_data(self, data2):
        """Concatenate this dataset with ``data2`` row-wise.

        Before concatenating, a ``(self.meta_label, "Dataset")`` column holding
        each object's ``name`` is added to any frame that lacks one, so the
        origin of every spectrum is retained. This adds the column to
        ``self.df`` and ``data2.df`` in place.

        For each of the meta/spect/comp labels, the combined object uses this
        object's label if it is not None, otherwise ``data2``'s label.

        Returns:
            A new SpectralData built from the concatenated frames, with the
            row index reset.
        """
        dataset_col = (self.meta_label, "Dataset")
        # If there is already a dataset column, leave it alone.
        if dataset_col not in self.df.columns:
            self.df[dataset_col] = self.name
        if dataset_col not in data2.df.columns:
            data2.df[dataset_col] = data2.name

        return SpectralData(
            pd.concat([self.df, data2.df], ignore_index=True),
            meta_label=self._first_not_none(self.meta_label, data2.meta_label),
            spect_label=self._first_not_none(self.spect_label, data2.spect_label),
            comp_label=self._first_not_none(self.comp_label, data2.comp_label),
        )

    def copy_spectral_data(self, new_name):
        """Return a deep copy of this object with its ``name`` set to ``new_name``."""
        new_data = copy.deepcopy(self)
        new_data.name = new_name
        return new_data

    def deriv(self):
        """Replace ``self.df`` with the result of ``deriv.deriv``."""
        self.df = deriv.deriv(self.df, spect_label=self.spect_label)

    def dim_red(self, col, method, params, kws, load_fit, ycol=None):
        """Run ``dim_red.dim_red``; stores the fitted object in ``self.do_dim_red``."""
        self.df, self.do_dim_red = dim_red.dim_red(
            self.df, col, method, params, kws, load_fit=load_fit, ycol=ycol
        )

    def interp(self, xnew):
        """Replace ``self.df`` with the result of ``interp.interp`` onto ``xnew``."""
        self.df = interp.interp(self.df, xnew, spect_label=self.spect_label)

    def shift(self, shift):
        """Replace ``self.df`` with the result of ``shift_spect.shift_spect``."""
        self.df = shift_spect.shift_spect(self.df, shift, spect_label=self.spect_label)

    def mask(self, maskfile, maskvar):
        """Replace ``self.df`` with the result of ``mask.mask``."""
        self.df = mask.mask(self.df, maskfile, maskvar=maskvar)

    def multiply_vector(self, vectorfile):
        """Replace ``self.df`` with the result of ``multiply_vector.multiply_vector``."""
        self.df = multiply_vector.multiply_vector(
            self.df, vectorfile=vectorfile, spect_label=self.spect_label
        )

    def norm(self, ranges, col_var):
        """Replace ``self.df`` with the result of ``norm.norm``."""
        self.df = norm.norm(self.df, ranges, col_var=col_var)

    def outlier_identify(self, col, method, params):
        """Replace ``self.df`` with the result of ``outlier_identify.outlier_identify``."""
        self.df = outlier_identify.outlier_identify(
            self.df, col=col, method=method, params=params
        )

    def endmember_identify(self, col, method, n_endmembers, meta_label='meta'):
        """Replace ``self.df`` with the frame returned by ``emi``.

        The indices also returned by ``emi`` are not kept.
        """
        self.df, _indices = endmember_identify.emi(
            self.df, col=col, emi_method=method, n_endmembers=n_endmembers, meta_label=meta_label
        )

    def peak_area(self, peaks_mins_file):
        """Run ``peak_area.peak_area`` and switch ``spect_label`` to "peak_area".

        The returned peaks and minima are stored in ``self.peaks`` and
        ``self.mins``.
        """
        self.df, self.peaks, self.mins = peak_area.peak_area(
            self.df, peaks_mins_file=peaks_mins_file, spect_label=self.spect_label
        )
        self.spect_label = "peak_area"

    def random_folds(self, nfolds):
        """Replace ``self.df`` with the result of ``folds.random``."""
        self.df = folds.random(self.df, nfolds, meta_label=self.meta_label)

    def remove_baseline(self, method, segment, params):
        """Run ``remove_baseline.remove_baseline``; the baseline goes to ``self.df_baseline``."""
        self.df, self.df_baseline = remove_baseline.remove_baseline(
            self.df,
            method=method,
            segment=segment,
            params=params,
            spect_label=self.spect_label,
        )

    def stratified_folds(self, nfolds, col, tiebreaker, comp_label="comp", meta_label="meta"):
        """Assign stratified folds sorted by ``(comp_label, col)``.

        ``tiebreaker``, if given, is also looked up under ``comp_label``.
        """
        if tiebreaker is not None:
            tiebreaker = (comp_label, tiebreaker)
        self.df = folds.stratified_folds(
            self.df,
            nfolds=nfolds,
            sortby=(comp_label, col),
            tiebreaker=tiebreaker,
            meta_label=meta_label
        )

    def enumerate_duplicates(self, col):
        """Replace ``self.df`` with the result of ``enumerate_duplicates``."""
        self.df = enumerate_duplicates(self.df, col=col)

    def scale(self, df_to_fit=None):
        """Run ``scale.do_scale``; the fitted scaler is stored in ``self.scaler``."""
        self.df, self.scaler = scale.do_scale(
            self.df, df_to_fit, spect_label=self.spect_label
        )

    def unmix(self, endmembers_df, method, params, normalize, meta_label='meta'):
        """Unmix this object's spectra using the endmembers in ``endmembers_df``.

        Rows of ``endmembers_df`` are used as endmembers unless the first
        ``meta_label`` column whose name contains 'Endmembers' is equal to
        'Not Endmember'.

        Returns:
            Whatever ``unmix.unmix`` returns.
        """
        meta_cols = endmembers_df[meta_label].columns.values
        flag_col = (meta_label, meta_cols[['Endmembers' in c for c in meta_cols]][0])
        is_endmember = np.squeeze(
            np.array((endmembers_df[flag_col] != 'Not Endmember'))
        )
        endmembers = endmembers_df.iloc[is_endmember, :]
        return unmix.unmix(
            np.array(self.df[self.spect_label]),
            endmembers[self.spect_label],
            method,
            params=params,
            normalize=normalize,
        )

    def lookup(self, lookupdata, left_on, right_on):
        """Replace ``self.df`` with the result of ``lookup``."""
        self.df = lookup(
            self.df, lookupdf=lookupdata, left_on=left_on, right_on=right_on
        )

    def remove_rows(self, matching_values):
        """Replace ``self.df`` with the result of ``remove_rows``."""
        self.df = remove_rows(self.df, matching_values, spect_label=self.spect_label)

    def closest_wvl(self, input_wvls):
        """Return, for each value in ``input_wvls``, the nearest spectral column label."""
        wvls = self.df[self.spect_label].columns.values
        return [wvls[(np.abs(wvls - w)).argmin()] for w in input_wvls]

    def remove_unnamed(self):
        """Remove top-level column groups whose label starts with "Unnamed".

        Remaining groups are re-assembled in the order of the (used) top-level
        index levels.
        """
        # Unused levels (left behind by e.g. DataFrame.drop) must be pruned,
        # otherwise self.df[c] below raises KeyError for them.
        top_levels = self.df.columns.remove_unused_levels().levels[0]
        # astype(str) so non-string levels simply do not match.
        colmask = np.asarray(top_levels.astype(str).str.match("Unnamed"), dtype=bool)
        if not colmask.any():
            return

        print("Removing unnamed columns:")
        print(top_levels[colmask])

        good_data = []
        for c in top_levels[~colmask]:
            data_tmp = self.df[c]
            data_tmp.columns = pd.MultiIndex.from_tuples(
                [(c, col) for col in data_tmp.columns.values]
            )
            good_data.append(data_tmp)
        self.df = pd.concat(good_data, axis=1)

    def remove_duplicates(self):
        """Drop spectral columns whose wavelength label cannot be made a float.

        Such labels typically come from duplicated wavelengths that were
        renamed on load. Surviving wavelength labels are converted to float,
        and the spectral columns are placed after all other columns. If there
        are no spectral columns, or none of them has a numeric label,
        ``self.df`` is left unchanged.
        """
        if self.spect_label is None:
            return
        try:
            data_wvl = self.df[self.spect_label]
        except KeyError:
            return
        data_no_wvl = self.df.drop(columns=self.spect_label)

        good_wvls = []
        new_labels = []
        for wvl in data_wvl.columns:
            try:
                new_labels.append((self.spect_label, float(wvl)))
                good_wvls.append(True)
            except (TypeError, ValueError, OverflowError):
                print("Removing duplicate column " + str(wvl))
                good_wvls.append(False)

        if not new_labels:
            return

        data_wvl = data_wvl.iloc[:, good_wvls]
        data_wvl.columns = pd.MultiIndex.from_tuples(new_labels)
        # Column-wise concat of frames sharing the same row index; unlike an
        # index merge, this does not multiply rows when the index has
        # duplicates.
        self.df = pd.concat([data_no_wvl, data_wvl], axis=1)

    # Disabled summary-parameter dispatchers, kept for reference. The m3 and
    # crism modules they call are not imported in this file.
    """
    def m3_params(self, paramname=None):
        if paramname is not None:
            if paramname == "R540":
                m3.r540(self)
            elif paramname == "R750":
                m3.r750(self)
            elif paramname == "R1580":
                m3.r1580(self)
            elif paramname == "R2780":
                m3.r2780(self)
            elif paramname == "VISNIR":
                m3.visnir(self)
            elif paramname == "R950_750":
                m3.r950_750(self)
            elif paramname == "2um_Ratio":
                m3.twoum_ratio(self)
            elif paramname == "Thermal_Ratio":
                m3.thermal_ratio(self)
            elif paramname == "Vis_Slope":
                m3.visslope(self)
            elif paramname == "1um_Slope":
                m3.oneum_slope(self)
            elif paramname == "2um_Slope":
                m3.twoum_slope(self)
            elif paramname == "BD620":
                m3.bd620(self)
            elif paramname == "BD950":
                m3.bd950(self)
            elif paramname == "BD1050":
                m3.bd1050(self)
            elif paramname == "BD1250":
                m3.bd1250(self)
            elif paramname == "BD3000":
                m3.bd3000(self)
            elif paramname == "BD1900":
                m3.bd1900(self)
            elif paramname == "BD2300":
                m3.bd2300(self)
            elif paramname == "BDI1000":
                m3.bdi1000(self)
            elif paramname == "BDI2000":
                m3.bdi2000(self)
            elif paramname == "OLINDEX":
                m3.olindex(self)
            elif paramname == "1um_min":
                m3.oneum_min(self)
            elif paramname == "1um_FWHM":
                m3.oneum_fwhm(self)
            elif paramname == "1um_symmetry":
                m3.oneum_sym(self)
            elif paramname == "BD1um_ratio":
                m3.bd1um_ratio(self)
            elif paramname == "BD2um_ratio":
                m3.bd2um_ratio(self)
            else:
                print(paramname + " is not recognized as a M3 summary " "parameter!")
    def crism_params(self, paramname=None):
        if paramname is not None:
            if paramname == "R440":
                crism.r440(self)
            elif paramname == "R530":
                crism.r530(self)
            elif paramname == "R600":
                crism.r600(self)
            elif paramname == "R770":
                crism.r770(self)
            elif paramname == "R1080":
                crism.r1080(self)
            elif paramname == "R1300":
                crism.r1300(self)
            elif paramname == "R1330":
                crism.r1330(self)
            elif paramname == "R1506":
                crism.r1506(self)
            elif paramname == "R2529":
                crism.r2529(self)
            elif paramname == "R3920":
                crism.r3920(self)
            elif paramname == "Red/Blue Ratio":
                crism.rbr(self)
            elif paramname == "BD530":
                crism.bd530_2(self)
            elif paramname == "BD640":
                crism.bd640_2(self)
            elif paramname == "BD860":
                crism.bd860_2(self)
            elif paramname == "BD920":
                crism.bd920_2(self)
            elif paramname == "BD1300":
                crism.bd1300(self)
            elif paramname == "BD1400":
                crism.bd1400(self)
            elif paramname == "BD1435":
                crism.bd1435(self)
            elif paramname == "BD1500":
                crism.bd1500_2(self)
            elif paramname == "BD1750":
                crism.bd1750_2(self)
            elif paramname == "BD1900":
                crism.bd1900_2(self)
            elif paramname == "BD1900r2":
                crism.bd1900r2(self)
            elif paramname == "BD2190":
                crism.bd2190(self)
            elif paramname == "BD2190":
                crism.bd2190(self)
            elif paramname == "BD2100":
                crism.bd2100_2(self)
            elif paramname == "BD2165":
                crism.bd2165(self)
            elif paramname == "BD2210":
                crism.bd2210_2(self)
            elif paramname == "BD2230":
                crism.bd2230(self)
            elif paramname == "BD2250":
                crism.bd2250(self)
            elif paramname == "BD2265":
                crism.bd2265(self)
            elif paramname == "BD2290":
                crism.bd2290(self)
            elif paramname == "BD2355":
                crism.bd2355(self)
            elif paramname == "BD2500h":
                crism.bd2500h_2(self)
            elif paramname == "BD2600":
                crism.bd2600(self)
            elif paramname == "BD3000":
                crism.crism_bd3000(self)
            elif paramname == "BD3100":
                crism.bd3100(self)
            elif paramname == "BD3200":
                crism.bd3200(self)
            elif paramname == "BD3400":
                crism.bd3400_2(self)
            elif paramname == "BDI1000VIS":
                crism.bdi1000VIS(self)
            elif paramname == "BDI1000IR":
                crism.bdi1000IR(self)
            elif paramname == "BDI2000":
                crism.crism_bdi2000(self)
            elif paramname == "SH600":
                crism.sh600_2(self)
            elif paramname == "SH770":
                crism.sh770(self)
            elif paramname == "SINDEX2":
                crism.sindex2(self)
            elif paramname == "CINDEX2":
                crism.cindex2(self)
            elif paramname == "RPEAK1":
                crism.rpeak1(self)
            elif paramname == "OLINDEX3":
                crism.olivine_index3(self)
            elif paramname == "LCPINDEX2":
                crism.lcp_index2(self)
            elif paramname == "HCPINDEX2":
                crism.hcp_index2(self)
            elif paramname == "ISLOPE1":
                crism.islope1(self)
            elif paramname == "ICER1_2":
                crism.icer1_2(self)
            elif paramname == "DOUB2200H":
                crism.doub2200h(self)
            elif paramname == "MIN2200":
                crism.min2200(self)
            elif paramname == "D2200":
                crism.d2200(self)
            elif paramname == "MIN2250":
                crism.min2250(self)
            elif paramname == "D2300":
                crism.d2300(self)
            elif paramname == "MIN2295_2480":
                crism.min2295_2480(self)
            elif paramname == "MIN2345_2537":
                crism.min2345_2537(self)
            elif paramname == "IRR1":
                crism.irr1(self)
            elif paramname == "IRR2":
                crism.irr2(self)
            elif paramname == "IRR3":
                crism.irr3(self)
            else:
                print(paramname + " is not recognized as a CRISM summary " "parameter!")
    """

    def remove_empty_spectra(self):
        """Remove every spectrum (row) that contains a NaN spectral value.

        Loading data can sometimes produce spectra containing NaNs. Nothing
        happens if ``spect_label`` is None.
        """
        if self.spect_label is None:
            return
        nan_mask = self.df[self.spect_label].isna().any(axis=1)
        # .any() rather than `np.max(...) is True`: the max is a numpy bool,
        # which is never identical to the Python True singleton.
        if nan_mask.any():
            print(
                str(np.sum(nan_mask)) + " spectra containing NaNs identified! These "
                "will be removed."
            )
            self.df = self.df.iloc[np.array(~nan_mask), :]
